package org.joone.engine;

import java.util.ArrayList;
import java.util.List;
import org.joone.engine.extenders.DeltaRuleExtender;
import org.joone.engine.extenders.GradientExtender;
import org.joone.engine.extenders.UpdateWeightExtender;

/* loaded from: input_file:org/joone/engine/ExtendableLearner.class */
/**
 * A learner whose update rule is assembled from pluggable extenders.
 *
 * <p>For every bias row and every weight, a delta is computed as
 * {@code learningRate * gradient}:
 * <ul>
 *   <li>the gradient starts from {@code getDefaultGradientBias}/{@code getDefaultGradientWeight}
 *       and is passed through every enabled {@link GradientExtender}, in insertion order;</li>
 *   <li>the delta starts from {@code getDefaultDelta} (which uses that gradient) and is passed
 *       through every enabled {@link DeltaRuleExtender}, in insertion order;</li>
 *   <li>the final delta is handed to the {@link UpdateWeightExtender}, which applies it.</li>
 * </ul>
 *
 * <p>Around each bias/weight update pass, pre-hooks run in this order: the subclass
 * {@code *Impl} hook, the update-weight extender (if set and enabled), enabled delta-rule
 * extenders, then enabled gradient extenders. Post-hooks run in the reverse order.
 *
 * <p>An {@link UpdateWeightExtender} must be set via {@link #setUpdateWeightExtender} before
 * {@link #requestBiasUpdate} or {@link #requestWeightUpdate} is called.
 * {@link #updateBias} and {@link #updateWeight} call it without a null check, and regardless
 * of its {@code isEnabled()} state.
 */
public class ExtendableLearner extends AbstractLearner {
    /** Delta-rule extenders, applied in insertion order to each computed delta. */
    protected List<DeltaRuleExtender> theDeltaRuleExtenders = new ArrayList<DeltaRuleExtender>();
    /** Gradient extenders, applied in insertion order to each computed gradient. */
    protected List<GradientExtender> theGradientExtenders = new ArrayList<GradientExtender>();
    /** Applies the final bias/weight deltas; may be null until set. */
    protected UpdateWeightExtender theUpdateWeightExtender;

    /**
     * Updates every bias row of the layer: runs the pre-hooks, applies one delta per row,
     * then runs the post-hooks.
     *
     * @param rowValues per-row values, indexed {@code 0 .. getLayer().getRows()-1}; the
     *                  default bias gradient for row {@code i} is {@code rowValues[i]}
     */
    @Override // org.joone.engine.Learner
    public final void requestBiasUpdate(double[] rowValues) {
        preBiasUpdate(rowValues);
        final int rows = getLayer().getRows();
        for (int row = 0; row < rows; row++) {
            updateBias(row, getDelta(rowValues, row));
        }
        postBiasUpdate(rowValues);
    }

    /**
     * Updates every weight of the synapse: runs the pre-hooks, applies one delta per
     * (input row, output column) pair, then runs the post-hooks.
     *
     * @param outputValues values indexed by output column,
     *                     {@code 0 .. getSynapse().getOutputDimension()-1}
     * @param inputValues  values indexed by input row,
     *                     {@code 0 .. getSynapse().getInputDimension()-1}
     */
    @Override // org.joone.engine.Learner
    public final void requestWeightUpdate(double[] outputValues, double[] inputValues) {
        preWeightUpdate(outputValues, inputValues);
        final int inputDim = getSynapse().getInputDimension();
        final int outputDim = getSynapse().getOutputDimension();
        for (int row = 0; row < inputDim; row++) {
            for (int column = 0; column < outputDim; column++) {
                updateWeight(row, column, getDelta(inputValues, row, outputValues, column));
            }
        }
        postWeightUpdate(outputValues, inputValues);
    }

    /**
     * Applies {@code delta} to bias row {@code row} through the update-weight extender.
     * Requires {@link #theUpdateWeightExtender} to be non-null.
     */
    protected void updateBias(int row, double delta) {
        this.theUpdateWeightExtender.updateBias(row, delta);
    }

    /**
     * Applies {@code delta} to weight ({@code row}, {@code column}) through the update-weight
     * extender. Requires {@link #theUpdateWeightExtender} to be non-null.
     */
    protected void updateWeight(int row, int column, double delta) {
        this.theUpdateWeightExtender.updateWeight(row, column, delta);
    }

    /**
     * Computes the final delta for bias row {@code row}: the default delta, refined by each
     * enabled delta-rule extender in insertion order.
     */
    protected double getDelta(double[] rowValues, int row) {
        double delta = getDefaultDelta(rowValues, row);
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                delta = extender.getDelta(rowValues, row, delta);
            }
        }
        return delta;
    }

    /** Returns {@code getLearningRate(row) * getGradientBias(rowValues, row)}. */
    public double getDefaultDelta(double[] rowValues, int row) {
        return getLearningRate(row) * getGradientBias(rowValues, row);
    }

    /**
     * Computes the final delta for weight ({@code row}, {@code column}): the default delta,
     * refined by each enabled delta-rule extender in insertion order.
     */
    protected double getDelta(double[] inputValues, int row, double[] outputValues, int column) {
        double delta = getDefaultDelta(inputValues, row, outputValues, column);
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                delta = extender.getDelta(inputValues, row, outputValues, column, delta);
            }
        }
        return delta;
    }

    /** Returns {@code getLearningRate(row, column) * getGradientWeight(...)} for the weight. */
    public double getDefaultDelta(double[] inputValues, int row, double[] outputValues, int column) {
        return getLearningRate(row, column) * getGradientWeight(inputValues, row, outputValues, column);
    }

    /**
     * Learning rate for bias row {@code row}. This base implementation ignores the index and
     * returns the monitor's global learning rate; subclasses may override.
     */
    protected double getLearningRate(int row) {
        return getMonitor().getLearningRate();
    }

    /**
     * Learning rate for weight ({@code row}, {@code column}). This base implementation ignores
     * the indices and returns the monitor's global learning rate; subclasses may override.
     */
    protected double getLearningRate(int row, int column) {
        return getMonitor().getLearningRate();
    }

    /**
     * Computes the gradient for bias row {@code row}: the default gradient, refined by each
     * enabled gradient extender in insertion order.
     */
    public double getGradientBias(double[] rowValues, int row) {
        double gradient = getDefaultGradientBias(rowValues, row);
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                gradient = extender.getGradientBias(rowValues, row, gradient);
            }
        }
        return gradient;
    }

    /** Default bias gradient: simply {@code rowValues[row]}. */
    public double getDefaultGradientBias(double[] rowValues, int row) {
        return rowValues[row];
    }

    /**
     * Computes the gradient for weight ({@code row}, {@code column}): the default gradient,
     * refined by each enabled gradient extender in insertion order.
     */
    public double getGradientWeight(double[] inputValues, int row, double[] outputValues, int column) {
        double gradient = getDefaultGradientWeight(inputValues, row, outputValues, column);
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                gradient = extender.getGradientWeight(inputValues, row, outputValues, column, gradient);
            }
        }
        return gradient;
    }

    /** Default weight gradient: {@code inputValues[row] * outputValues[column]}. */
    public double getDefaultGradientWeight(double[] inputValues, int row, double[] outputValues, int column) {
        return inputValues[row] * outputValues[column];
    }

    /**
     * Runs the bias pre-hooks. The order is: {@link #preBiasUpdateImpl}, then the update-weight
     * extender (if set and enabled), then enabled delta-rule extenders, then enabled gradient
     * extenders.
     */
    protected final void preBiasUpdate(double[] rowValues) {
        preBiasUpdateImpl(rowValues);
        if (this.theUpdateWeightExtender != null && this.theUpdateWeightExtender.isEnabled()) {
            this.theUpdateWeightExtender.preBiasUpdate(rowValues);
        }
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                extender.preBiasUpdate(rowValues);
            }
        }
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                extender.preBiasUpdate(rowValues);
            }
        }
    }

    /** Subclass hook run first in {@link #preBiasUpdate}; no-op by default. */
    protected void preBiasUpdateImpl(double[] rowValues) {
    }

    /**
     * Runs the weight pre-hooks, in the same order as {@link #preBiasUpdate}.
     * The subclass hook receives {@code (outputValues, inputValues)}. Extenders receive
     * {@code (inputValues, outputValues)}.
     * NOTE(review): the swapped order for extenders presumably matches the extender
     * interface's parameter order — confirm.
     */
    protected final void preWeightUpdate(double[] outputValues, double[] inputValues) {
        preWeightUpdateImpl(outputValues, inputValues);
        if (this.theUpdateWeightExtender != null && this.theUpdateWeightExtender.isEnabled()) {
            this.theUpdateWeightExtender.preWeightUpdate(inputValues, outputValues);
        }
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                extender.preWeightUpdate(inputValues, outputValues);
            }
        }
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                extender.preWeightUpdate(inputValues, outputValues);
            }
        }
    }

    /** Subclass hook run first in {@link #preWeightUpdate}; no-op by default. */
    protected void preWeightUpdateImpl(double[] outputValues, double[] inputValues) {
    }

    /**
     * Runs the bias post-hooks in reverse pre-hook order. The order is: enabled gradient
     * extenders, then enabled delta-rule extenders, then the update-weight extender (if set and
     * enabled), then {@link #postBiasUpdateImpl}.
     */
    protected final void postBiasUpdate(double[] rowValues) {
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                extender.postBiasUpdate(rowValues);
            }
        }
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                extender.postBiasUpdate(rowValues);
            }
        }
        if (this.theUpdateWeightExtender != null && this.theUpdateWeightExtender.isEnabled()) {
            this.theUpdateWeightExtender.postBiasUpdate(rowValues);
        }
        postBiasUpdateImpl(rowValues);
    }

    /** Subclass hook run last in {@link #postBiasUpdate}; no-op by default. */
    protected void postBiasUpdateImpl(double[] rowValues) {
    }

    /**
     * Runs the weight post-hooks in reverse pre-hook order. Extenders receive
     * {@code (inputValues, outputValues)}, as in {@link #preWeightUpdate}. The subclass hook
     * receives {@code (outputValues, inputValues)}, mirroring {@link #preWeightUpdateImpl}.
     */
    protected final void postWeightUpdate(double[] outputValues, double[] inputValues) {
        for (int k = 0; k < this.theGradientExtenders.size(); k++) {
            GradientExtender extender = this.theGradientExtenders.get(k);
            if (extender.isEnabled()) {
                extender.postWeightUpdate(inputValues, outputValues);
            }
        }
        for (int k = 0; k < this.theDeltaRuleExtenders.size(); k++) {
            DeltaRuleExtender extender = this.theDeltaRuleExtenders.get(k);
            if (extender.isEnabled()) {
                extender.postWeightUpdate(inputValues, outputValues);
            }
        }
        if (this.theUpdateWeightExtender != null && this.theUpdateWeightExtender.isEnabled()) {
            this.theUpdateWeightExtender.postWeightUpdate(inputValues, outputValues);
        }
        // Previously passed the input array twice; pass both arrays in the pre-hook order.
        postWeightUpdateImpl(outputValues, inputValues);
    }

    /** Subclass hook run last in {@link #postWeightUpdate}; no-op by default. */
    protected void postWeightUpdateImpl(double[] outputValues, double[] inputValues) {
    }

    /**
     * Appends a delta-rule extender and binds it to this learner.
     *
     * @throws NullPointerException if {@code deltaRuleExtender} is null (the list is left unchanged)
     */
    public void addDeltaRuleExtender(DeltaRuleExtender deltaRuleExtender) {
        if (deltaRuleExtender == null) {
            throw new NullPointerException("deltaRuleExtender");
        }
        this.theDeltaRuleExtenders.add(deltaRuleExtender);
        deltaRuleExtender.setLearner(this);
    }

    /**
     * Appends a gradient extender and binds it to this learner.
     *
     * @throws NullPointerException if {@code gradientExtender} is null (the list is left unchanged)
     */
    public void addGradientExtender(GradientExtender gradientExtender) {
        if (gradientExtender == null) {
            throw new NullPointerException("gradientExtender");
        }
        this.theGradientExtenders.add(gradientExtender);
        gradientExtender.setLearner(this);
    }

    /**
     * Sets the extender that applies the final deltas and binds it to this learner.
     * Passing {@code null} clears it; bias/weight updates will then fail until a new one is set.
     */
    public void setUpdateWeightExtender(UpdateWeightExtender updateWeightExtender) {
        this.theUpdateWeightExtender = updateWeightExtender;
        if (updateWeightExtender != null) {
            updateWeightExtender.setLearner(this);
        }
    }

    /** Returns the current update-weight extender, or {@code null} if none is set. */
    public UpdateWeightExtender getUpdateWeightExtender() {
        return this.theUpdateWeightExtender;
    }
}
